import torch

# Demo: shape bookkeeping for reshape and transpose on a 4-D tensor.
BATCH, CHANNELS, HEIGHT, WIDTH = 2, 3, 28, 28

a = torch.randn((BATCH, CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
print(a.shape)  # torch.Size([2, 3, 28, 28])

# reshape returns a new tensor (a view when the memory layout allows it, a
# copy otherwise); it never modifies `a` in place, so the result must be
# re-bound. The -1 lets torch infer the leading dimension: 2 * 3 = 6.
a = a.reshape(-1, HEIGHT * WIDTH)
print(a.shape)  # torch.Size([6, 784])

# transpose swaps dims 0 and 1 and returns a view sharing storage with its
# input (again not in place, so re-bind). The resulting view is
# non-contiguous; call .contiguous() before ops that require contiguity.
a = torch.transpose(a, 0, 1)
print(a.shape)  # torch.Size([784, 6])
